/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_ACTIVATION_SOFTMAXFLASH_KERNEL_H
#define EXAMPLES_ACTIVATION_SOFTMAXFLASH_KERNEL_H
#include "kernel_operator.h"
#include "softmaxflash_custom_tiling.h"

namespace MyCustomKernel {
// Number of buffers allocated for the input queue (2 enables double buffering of copy-in).
constexpr int32_t BUFFER_NUM{2};
// Float elements stored per row in the max/sum outputs (8 * sizeof(float) = 32 bytes).
constexpr uint32_t FLOAT_NUM_OF_SINGEL_BLOCK{8};
// SoftmaxFlashV2 basic-block path requires the row count to be a multiple of this value,
constexpr uint32_t BASIC_BLOCK_ROW_FACTOR{8};
// the column count to be a multiple of this value,
constexpr uint32_t BASIC_BLOCK_COLUMN_FACTOR{64};
// and the column count to be strictly less than this value.
constexpr uint32_t BASIC_BLOCK_MAX_COLUMN_LENGTH{2048};

/**
 * KernelSoftmax: per-core kernel that applies AscendC::SoftmaxFlashV2 to a float input in GM
 * and writes the resulting per-row max and sum tensors back to GM.
 *
 * Each core owns coreRowNum rows of columnLength floats. Those rows are processed in
 * singleCoreLoopCount M-chunks of singleLoopCoreRowNum rows, followed by an optional tail
 * M-chunk of leftRow rows. Each M-chunk is swept along the column axis in loopK K-chunks of
 * splitK columns plus an optional tail K-chunk of tailK columns. The first K-chunk of an
 * M-chunk calls SoftmaxFlashV2 with isUpdate == false; later K-chunks pass isUpdate == true
 * and feed back the max/sum produced by the previous K-chunk.
 *
 * Usage: Init(...) then Process().
 */
class KernelSoftmax {
public:
    __aicore__ inline KernelSoftmax() {}

    /**
     * Copy the host-computed tiling parameters into members.
     * @param tilingData tiling data produced on the host side.
     */
    __aicore__ inline void InitTiling(const SoftmaxflashCustomTilingData& tilingData)
    {
        rowLength = tilingData.rowLength;
        sharedTmpBufferSize = tilingData.sharedTmpBufferSize;
        columnLength = tilingData.columnLength;
        usedBlockDim = tilingData.usedBlockDim;
        coreRowNum = tilingData.coreRowNum;
        softmaxTiling = tilingData.softmaxTilingData;
        singleLoopCoreRowNum = tilingData.singleLoopCoreRowNum;
        singleCoreLoopCount = tilingData.singleCoreLoopCount;
        leftRow = tilingData.singleCoreLoopTail;
        tailCoreSingleLoopCoreRowNum = tilingData.tailCoreSingleLoopCoreRowNum;
        tailCoreSingleCoreLoopCount = tilingData.tailCoreSingleCoreLoopCount;
        tailCoreSingleCoreLoopTail = tilingData.tailCoreSingleCoreLoopTail;
        splitK = tilingData.splitK;
        loopK = tilingData.loopK;
        tailK = tilingData.tailK;
    }

    /**
     * Bind this core's slices of the GM tensors and allocate its local buffers.
     * @param x          GM address of the float input; this core reads coreRowNum rows of columnLength.
     * @param max        GM address of the float max output (FLOAT_NUM_OF_SINGEL_BLOCK values per row).
     * @param sum        GM address of the float sum output (FLOAT_NUM_OF_SINGEL_BLOCK values per row).
     * @param tilingData host-side tiling parameters.
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR max, GM_ADDR sum, const SoftmaxflashCustomTilingData& tilingData)
    {
        ASSERT(AscendC::GetBlockNum() != 0 && "block dim can not be zero!");
        InitTiling(tilingData);

        // The core whose index equals usedBlockDim is the tail core and uses the tail-core loop parameters.
        if (AscendC::GetBlockIdx() == this->usedBlockDim) {
            this->singleLoopCoreRowNum = this->tailCoreSingleLoopCoreRowNum;
            this->singleCoreLoopCount = this->tailCoreSingleCoreLoopCount;
            this->leftRow = this->tailCoreSingleCoreLoopTail;
        }

        this->blockLength = this->coreRowNum * this->columnLength;
        this->msLength = this->coreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK; // max sum length per block process

        // Every preceding core owns exactly coreRowNum rows, so offsets scale with the block index.
        uint32_t xOffset = this->blockLength * AscendC::GetBlockIdx();
        uint32_t msOffset = this->msLength * AscendC::GetBlockIdx();

        xGm.SetGlobalBuffer((__gm__ float*)x + xOffset, this->blockLength);
        maxGm.SetGlobalBuffer((__gm__ float*)max + msOffset, this->msLength);
        sumGm.SetGlobalBuffer((__gm__ float*)sum + msOffset, this->msLength);

        // One queue buffer holds a full [singleLoopCoreRowNum x splitK] input tile.
        this->tileLength = this->singleLoopCoreRowNum * splitK;
        pipe.InitBuffer(queueX, BUFFER_NUM, this->tileLength * sizeof(float));

        // max/sum/expmax hold FLOAT_NUM_OF_SINGEL_BLOCK floats per row of an M-chunk.
        this->msTileLength = this->singleLoopCoreRowNum * FLOAT_NUM_OF_SINGEL_BLOCK;
        pipe.InitBuffer(tbufMax, this->msTileLength * sizeof(float));
        pipe.InitBuffer(tbufSum, this->msTileLength * sizeof(float));
        pipe.InitBuffer(tbufExpmax, this->msTileLength * sizeof(float));
        maxLocal = tbufMax.Get<float>();
        sumLocal = tbufSum.Get<float>();
        expmaxLocal = tbufExpmax.Get<float>();

        pipe.InitBuffer(sharedTmpBuffer, sharedTmpBufferSize);
    }

    /**
     * Run the softmax over every M-chunk owned by this core and copy max/sum to GM after each chunk.
     * Cores whose index exceeds usedBlockDim have no rows assigned and return immediately.
     */
    __aicore__ inline void Process()
    {
        if (AscendC::GetBlockIdx() > this->usedBlockDim) {
            return;
        }

        for (uint32_t i = 0; i < this->singleCoreLoopCount; i++) { // split M
            ProcessRowChunk(i, this->singleLoopCoreRowNum);
            // copy max sum to gm
            CopyOutAfterCompute(i, this->msTileLength);
            // The next chunk's vector computation overwrites maxLocal/sumLocal, so wait until MTE3
            // has finished reading them. Doing this only once after the loop left a write-after-read
            // hazard between iterations.
            event_t eventIdMte3ToV = static_cast<event_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::MTE3_V));
            AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventIdMte3ToV);
            AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventIdMte3ToV);
        }

        if (this->leftRow > 0) {
            ProcessRowChunk(this->singleCoreLoopCount, this->leftRow);
            // copy max sum to gm
            uint32_t tailMsTileLength = this->leftRow * FLOAT_NUM_OF_SINGEL_BLOCK;
            CopyOutAfterCompute(this->singleCoreLoopCount, tailMsTileLength);
        }
    }

private:
    /**
     * Sweep all K-chunks (loopK chunks of splitK columns, then the tailK remainder) of one M-chunk.
     * @param rowIndex M-chunk index.
     * @param rowNum   number of rows in this M-chunk.
     */
    __aicore__ inline void ProcessRowChunk(uint32_t rowIndex, uint32_t rowNum)
    {
        for (uint32_t j = 0; j < this->loopK; j++) { // split K
            CopyIn(rowIndex, j, rowNum, this->splitK);
            Compute(j, rowNum, this->splitK);
        }

        if (this->tailK > 0) {
            CopyIn(rowIndex, this->loopK, rowNum, this->tailK);
            Compute(this->loopK, rowNum, this->tailK);
        }
    }

    /**
     * Load one [rowNum x columnNum] tile of x into a queue buffer, one row at a time.
     * @param rowIndex  M-chunk index; chunk rows start at rowIndex * singleLoopCoreRowNum.
     * @param kIndex    K-chunk index; tile columns start at kIndex * splitK.
     * @param rowNum    rows in this tile.
     * @param columnNum columns in this tile (splitK, or tailK for the last K-chunk).
     */
    __aicore__ inline void CopyIn(uint32_t rowIndex, uint32_t kIndex, uint32_t rowNum, uint32_t columnNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.AllocTensor<float>();
        // GM offset of the tile's first element; consecutive rows of x are columnLength apart.
        uint32_t tileBase = rowIndex * this->singleLoopCoreRowNum * this->columnLength + kIndex * this->splitK;
        for (uint32_t i = 0; i < rowNum; i++) {
            // NOTE(review): DataCopy has alignment requirements; presumably the tiling keeps columnNum
            // (including tailK) suitably aligned. Confirm with the host tiling code.
            AscendC::DataCopy(xLocal[i * columnNum], xGm[tileBase + i * this->columnLength], columnNum);
        }
        queueX.EnQue(xLocal);
    }

    /**
     * Invoke SoftmaxFlashV2 in place on xLocal, reading and writing maxLocal/sumLocal/expmaxLocal.
     * @tparam isUpdate   false for the first K-chunk of an M-chunk, true for later chunks.
     * @param isBasicBlock whether the shape meets the basic-block constraints checked in Compute.
     */
    template <bool isUpdate>
    __aicore__ inline void RunSoftmaxFlash(const AscendC::LocalTensor<float>& xLocal,
        const AscendC::LocalTensor<uint8_t>& tmpBuffer, const AscendC::SoftMaxShapeInfo& srcShape, bool isBasicBlock)
    {
        if (isBasicBlock) {
            AscendC::SoftmaxFlashV2<float, isUpdate, true, true>(xLocal, sumLocal, maxLocal, xLocal, expmaxLocal,
                sumLocal, maxLocal, tmpBuffer, softmaxTiling, srcShape);
        } else {
            AscendC::SoftmaxFlashV2<float, isUpdate, true, false>(xLocal, sumLocal, maxLocal, xLocal, expmaxLocal,
                sumLocal, maxLocal, tmpBuffer, softmaxTiling, srcShape);
        }
    }

    /**
     * Dequeue one input tile and run SoftmaxFlashV2 on it, then release the tile.
     * @param columnIndex K-chunk index; 0 selects isUpdate == false, otherwise isUpdate == true.
     * @param rowNum      rows in the tile.
     * @param columnNum   columns in the tile.
     */
    __aicore__ inline void Compute(uint32_t columnIndex, uint32_t rowNum, uint32_t columnNum)
    {
        AscendC::LocalTensor<float> xLocal = queueX.DeQue<float>();
        AscendC::LocalTensor<uint8_t> tmpBuffer = sharedTmpBuffer.Get<uint8_t>();

        AscendC::SoftMaxShapeInfo srcShape = { rowNum, columnNum, rowNum, columnNum };
        // Shapes meeting these constraints use the basic-block specialisation of SoftmaxFlashV2.
        bool isBasicBlock = (rowNum % BASIC_BLOCK_ROW_FACTOR == 0) &&
            (columnNum % BASIC_BLOCK_COLUMN_FACTOR == 0) &&
            (columnNum < BASIC_BLOCK_MAX_COLUMN_LENGTH);
        if (columnIndex == 0) { // isUpdate == false
            RunSoftmaxFlash<false>(xLocal, tmpBuffer, srcShape, isBasicBlock);
        } else {
            RunSoftmaxFlash<true>(xLocal, tmpBuffer, srcShape, isBasicBlock);
        }
        queueX.FreeTensor(xLocal);
    }

    /**
     * Wait for vector writes to maxLocal/sumLocal to finish, then copy them to GM.
     * @param progress M-chunk index; GM offset is progress * msTileLength.
     * @param count    number of floats to copy for each of max and sum.
     */
    __aicore__ inline void CopyOutAfterCompute(uint32_t progress, uint32_t count)
    {
        event_t eventIdVToMte3 = static_cast<event_t>(GetTPipePtr()->FetchEventID(AscendC::HardEvent::V_MTE3));
        AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventIdVToMte3);
        AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventIdVToMte3);
        CopyOut(progress, count);
    }

    /**
     * Copy `count` floats of maxLocal/sumLocal into the GM slot of M-chunk `progress`.
     * @param progress M-chunk index; GM offset is progress * msTileLength.
     * @param count    number of floats to copy for each of max and sum.
     */
    __aicore__ inline void CopyOut(uint32_t progress, uint32_t count)
    {
        AscendC::DataCopy(maxGm[progress * this->msTileLength], maxLocal, count);
        AscendC::DataCopy(sumGm[progress * this->msTileLength], sumLocal, count);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TBuf<AscendC::TPosition::VECCALC> sharedTmpBuffer;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> queueX;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tbufMax, tbufSum, tbufExpmax;
    AscendC::GlobalTensor<float> xGm;
    AscendC::GlobalTensor<float> maxGm;
    AscendC::GlobalTensor<float> sumGm;
    AscendC::LocalTensor<float> maxLocal;
    AscendC::LocalTensor<float> sumLocal;
    AscendC::LocalTensor<float> expmaxLocal;

    uint32_t blockLength = 0;          // input floats owned by one core
    uint32_t usedBlockDim = 0;         // index of the tail core; higher indices do no work
    uint32_t msLength = 0;             // max/sum floats owned by one core
    uint32_t rowLength = 0;
    uint32_t columnLength = 0;         // floats per input row
    uint32_t coreRowNum = 0;           // rows owned by one core
    uint32_t tileLength = 0;           // floats per input queue buffer
    uint32_t msTileLength = 0;         // max/sum floats per full M-chunk
    uint32_t sharedTmpBufferSize = 0;  // bytes of scratch space for SoftmaxFlashV2
    uint32_t singleLoopCoreRowNum = 0; // rows per full M-chunk
    uint32_t singleCoreLoopCount = 0;  // number of full M-chunks
    uint32_t leftRow = 0;              // rows in the tail M-chunk (0 if none)
    uint32_t tailCoreSingleLoopCoreRowNum = 0;
    uint32_t tailCoreSingleCoreLoopCount = 0;
    uint32_t tailCoreSingleCoreLoopTail = 0;
    uint32_t splitK = 0;               // columns per full K-chunk
    uint32_t loopK = 0;                // number of full K-chunks
    uint32_t tailK = 0;                // columns in the tail K-chunk (0 if none)
    SoftMaxTiling softmaxTiling;
};
}
#endif // EXAMPLES_ACTIVATION_SOFTMAXFLASH_KERNEL_H